{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wJcYs_ERTnnI"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "HMUDt0CiUJk9"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "77z2OchJTk0l"
      },
      "source": [
        "# Migration examples: LoggingTensorHook and StopAtStepHook\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/migrate/logging_stop_hook\">\n",
        "    <img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />\n",
        "    View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />\n",
        "    Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />\n",
        "    View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/migrate/logging_stop_hook.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "meUTrR4I6m1C"
      },
      "source": [
        "This notebook demonstrates how to migrate `tf.estimator.LoggingTensorHook` and `tf.estimator.StopAtStepHook` from TF1 Estimator training to custom `tf.keras.callbacks.Callback` subclasses used with Keras `Model.fit` in TF2."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YdZSoIXEbhg-"
      },
      "source": [
        "## Setup\n",
        "\n",
        "Start by importing TensorFlow and its TF1 compatibility module, `tf.compat.v1`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iE0vSfMXumKI"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow.compat.v1 as tf1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jsm9Rxx7s1OZ"
      },
      "source": [
        "Prepare some simple data for demonstration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m7rnGxsXtDkV"
      },
      "outputs": [],
      "source": [
        "# Toy regression data: three examples, two features each, one target.\n",
        "features = [[1., 1.5], [2., 2.5], [3., 3.5]]\n",
        "labels = [[0.3], [0.5], [0.7]]\n",
        "\n",
        "\n",
        "def _input_fn():\n",
        "  \"\"\"Returns a dataset of (features, labels) pairs, one example per batch.\"\"\"\n",
        "  dataset = tf1.data.Dataset.from_tensor_slices((features, labels))\n",
        "  return dataset.batch(1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4uXff1BEssdE"
      },
      "source": [
        "### TF1: Estimator.train"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zW-X5cmzmkuw"
      },
      "source": [
        "In TF1, you control training behavior with hooks, which you pass to `tf.estimator.EstimatorSpec` (here, through its `training_hooks` argument)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5EtQXPcqmxe5"
      },
      "source": [
        "- To monitor tensors, for example model weights or losses, you can use `tf.estimator.LoggingTensorHook` (`tf1.train.LoggingTensorHook` is its alias).\n",
        "\n",
        "- To stop training at a specific step, you can use `tf.estimator.StopAtStepHook` (`tf1.train.StopAtStepHook` is its alias)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lqe9obf7suIj"
      },
      "outputs": [],
      "source": [
        "def _model_fn(features, labels, mode):\n",
        "  \"\"\"Builds a one-unit linear regressor whose training is controlled by hooks.\"\"\"\n",
        "  dense = tf1.layers.Dense(1)\n",
        "  predictions = dense(features)\n",
        "  loss = tf1.losses.mean_squared_error(labels=labels, predictions=predictions)\n",
        "  train_op = tf1.train.AdagradOptimizer(0.05).minimize(\n",
        "      loss, global_step=tf1.train.get_global_step())\n",
        "\n",
        "  # Stop training once 2 steps have run.\n",
        "  stop_hook = tf1.train.StopAtStepHook(num_steps=2)\n",
        "\n",
        "  # Log the layer's kernel and bias on every iteration. The hook receives\n",
        "  # tensor objects here; `tf.identity` turns each variable into a tensor.\n",
        "  weight_tensors = [tf.identity(weight) for weight in dense.weights[:2]]\n",
        "  logging_weight_hook = tf1.train.LoggingTensorHook(\n",
        "      tensors=weight_tensors, every_n_iter=1)\n",
        "\n",
        "  # Log the training loss under a readable tag, at most once every 3 seconds.\n",
        "  logging_loss_hook = tf1.train.LoggingTensorHook(\n",
        "      {'loss from LoggingTensorHook': loss}, every_n_secs=3)\n",
        "\n",
        "  # Every hook is passed to the EstimatorSpec as a training hook.\n",
        "  hooks = [stop_hook, logging_weight_hook, logging_loss_hook]\n",
        "  return tf1.estimator.EstimatorSpec(\n",
        "      mode, loss=loss, train_op=train_op, training_hooks=hooks)\n",
        "\n",
        "\n",
        "estimator = tf1.estimator.Estimator(model_fn=_model_fn)\n",
        "\n",
        "# Training stops after 2 steps; weights and loss are logged along the way.\n",
        "estimator.train(_input_fn)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KEmzBjfnsxwT"
      },
      "source": [
        "### TF2: Keras training API"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "839R9i4xheI5"
      },
      "source": [
        "In TF2, you can define custom callbacks for Keras training API. Check the API [docs](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback) and [Writing your own callbacks](https://www.tensorflow.org/guide/keras/custom_callback) for more details.\n",
        "\n",
        "- To migrate `StopAtStepHook`, override the `on_batch_end` method to specify when training should stop.\n",
        "\n",
        "- To migrate `LoggingTensorHook`: accessing tensors by name is not supported, so you need to record and output the logged tensors manually.\n",
        "You can also implement the logging frequency in the custom callback. The example below logs the weights every two steps. Other strategies, such as logging every n seconds, are also possible."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "atVciNgPs0fw"
      },
      "outputs": [],
      "source": [
        "class StopAtStepCallback(tf.keras.callbacks.Callback):\n",
        "  \"\"\"Stops `Model.fit` once the optimizer has run `stop_step` steps.\n",
        "\n",
        "  Keras replacement for `tf.estimator.StopAtStepHook`. With the default\n",
        "  `stop_step=None`, training is never stopped early.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, stop_step=None):\n",
        "    super().__init__()\n",
        "    self._stop_step = stop_step\n",
        "\n",
        "  def on_batch_end(self, batch, logs=None):\n",
        "    # `None` means \"no limit\"; comparing the step counter to None would fail.\n",
        "    if self._stop_step is None:\n",
        "      return\n",
        "    # `optimizer.iterations` counts steps across epochs, whereas `batch` is\n",
        "    # the index within the current epoch.\n",
        "    if self.model.optimizer.iterations >= self._stop_step:\n",
        "      self.model.stop_training = True\n",
        "      print('\\nstop training now')\n",
        "\n",
        "class LoggingTensorCallback(tf.keras.callbacks.Callback):\n",
        "  \"\"\"Logs the first layer's weights and the loss every `every_n_iter` batches.\n",
        "\n",
        "  Like `tf.estimator.LoggingTensorHook(every_n_iter=...)`, the first training\n",
        "  batch is logged, then every `every_n_iter`-th batch after it.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, every_n_iter):\n",
        "    super().__init__()\n",
        "    if every_n_iter < 1:\n",
        "      raise ValueError(f\"every_n_iter must be >= 1, got {every_n_iter}\")\n",
        "    self._every_n_iter = every_n_iter\n",
        "    # Batches remaining until the next log; 0 means log on the next batch.\n",
        "    self._log_count = 0\n",
        "\n",
        "  def on_batch_end(self, batch, logs=None):\n",
        "    if self._log_count <= 0:\n",
        "      logs = logs or {}\n",
        "      # Read weights from the model this callback is attached to rather\n",
        "      # than from a notebook-level `model` variable.\n",
        "      layer = self.model.layers[0]\n",
        "      print(\"Logging Tensor Callback: dense/kernel:\", layer.weights[0])\n",
        "      print(\"Logging Tensor Callback: dense/bias:\", layer.weights[1])\n",
        "      print(\"Logging Tensor Callback loss:\", logs.get(\"loss\"))\n",
        "      self._log_count = self._every_n_iter\n",
        "    self._log_count -= 1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Kip65sYBlKiu"
      },
      "outputs": [],
      "source": [
        "# Same toy data as in the Estimator example, one example per batch.\n",
        "dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)\n",
        "\n",
        "# A single Dense unit trained with Adagrad on mean squared error.\n",
        "model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])\n",
        "optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)\n",
        "model.compile(optimizer=optimizer, loss=\"mse\")\n",
        "\n",
        "# Training stops after 2 steps; weights and loss are logged along the way.\n",
        "model.fit(\n",
        "    dataset,\n",
        "    callbacks=[\n",
        "        StopAtStepCallback(stop_step=2),\n",
        "        LoggingTensorCallback(every_n_iter=2),\n",
        "    ])"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "logging_stop_hook.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
